# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Loss functions."""
import cv2
import numpy as np
import torch
import dnnlib
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d
from training.dual_discriminator import filtered_resizing
from torchvision.transforms import Resize
#----------------------------------------------------------------------------

class Loss:
    """Base interface for losses that accumulate gradients for one training phase."""

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass
        """Compute the loss for ``phase`` and accumulate its gradients.

        Subclasses must override this method; the base version always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

#----------------------------------------------------------------------------

class StyleGAN2Loss(Loss):
    """Reconstruction-based generator loss built on the StyleGAN2/EG3D loss scaffolding.

    Only the generator reconstruction objective is active: L1 plus a VGG16-LPIPS
    perceptual distance, applied to both the ``'image'`` and ``'image_raw'`` outputs
    of G. The adversarial, density-regularization and discriminator phases are kept
    as commented-out code inside ``accumulate_gradients``.
    """

    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0, gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased'):
        """Store hyperparameters and load the frozen VGG16 feature network.

        The signature mirrors the StyleGAN2/EG3D loss. Several arguments (e.g.
        ``augment_pipe``, ``style_mixing_prob``, the ``pl_*`` and ``r1_gamma_*``
        options, ``dual_discrimination``, ``filter_mode``) are stored but not read
        by any active code path in this class.

        The VGG16 TorchScript model is opened from NVIDIA's CDN URL via
        ``dnnlib.util.open_url`` during construction.
        """
        super().__init__()
        self.device             = device
        self.G                  = G
        self.D                  = D
        self.augment_pipe       = augment_pipe
        self.r1_gamma           = r1_gamma
        self.style_mixing_prob  = style_mixing_prob
        self.pl_weight          = pl_weight
        self.pl_batch_shrink    = pl_batch_shrink
        self.pl_decay           = pl_decay
        self.pl_no_weight_grad  = pl_no_weight_grad
        self.pl_mean            = torch.zeros([], device=device)
        self.blur_init_sigma    = blur_init_sigma
        self.blur_fade_kimg     = blur_fade_kimg
        self.r1_gamma_init      = r1_gamma_init
        self.r1_gamma_fade_kimg = r1_gamma_fade_kimg
        self.neural_rendering_resolution_initial = neural_rendering_resolution_initial
        self.neural_rendering_resolution_final = neural_rendering_resolution_final
        self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg
        self.gpc_reg_fade_kimg = gpc_reg_fade_kimg
        self.gpc_reg_prob = gpc_reg_prob
        self.dual_discrimination = dual_discrimination
        self.filter_mode = filter_mode
        self.resample_filter = upfirdn2d.setup_filter([1,3,3,1], device=device)
        self.blur_raw_target = False
        self.l1_loss = torch.nn.L1Loss()
        # Resizes real images to the (initial) neural-rendering resolution to form the 'image_raw' target.
        self.torch_resize = Resize([neural_rendering_resolution_initial, neural_rendering_resolution_initial])
        url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
        with dnnlib.util.open_url(url) as f:
            self.vgg16 = torch.jit.load(f).eval().to(device)
        # VGG16 is only a fixed feature extractor. Without freezing it, every reconstruction
        # backward pass also computes and accumulates gradients for its parameters, which
        # nothing in this class ever uses or clears.
        self.vgg16.requires_grad_(False)
        assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)

    def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False):
        """Run G's mapping and synthesis networks.

        The mapping network is conditioned on ``c_gen_conditioning``:
        - when ``swapping_prob`` is given, each sample's ``c`` is replaced, with that
          probability, by the previous sample's ``c`` (cyclic ``torch.roll`` over the batch);
        - when ``swapping_prob`` is None, the conditioning is all zeros.
        Synthesis always receives the original ``c``.

        Returns:
            ``(gen_output, ws)``: the synthesis output and the mapped latents.
        """
        if swapping_prob is not None:
            c_swapped = torch.roll(c.clone(), 1, 0)
            c_gen_conditioning = torch.where(torch.rand((c.shape[0], 1), device=c.device) < swapping_prob, c_swapped, c)
        else:
            c_gen_conditioning = torch.zeros_like(c)

        ws = self.G.mapping(z, c_gen_conditioning, update_emas=update_emas)
        gen_output = self.G.synthesis(ws, c, neural_rendering_resolution=neural_rendering_resolution, update_emas=update_emas)
        return gen_output, ws

    def run_D(self, img, c, blur_sigma=0, blur_sigma_raw=0, update_emas=False):
        """Optionally blur ``img['image']``, then return the discriminator logits.

        When ``floor(blur_sigma * 3) > 0``, ``img['image']`` is filtered with the
        normalized kernel ``2 ** -(x / blur_sigma) ** 2``. The blurred tensor is written
        back into the caller's ``img`` dict, so the caller observes the replacement.
        ``blur_sigma_raw`` is accepted for interface compatibility but unused.
        """
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(blur_sigma).square().neg().exp2()
            img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())

        logits = self.D(img, c, update_emas=update_emas)
        return logits

    def _perceptual_loss(self, target, synth):
        """Return the squared VGG16-LPIPS feature distance between two image batches, summed.

        Inputs are mapped with ``(x + 1) * 255 * 0.5`` before the network (i.e. they are
        assumed to lie in [-1, 1]; confirm against the data pipeline). The target side
        needs no gradients, so its features are computed under ``torch.no_grad()``.
        """
        with torch.no_grad():
            target_features = self.vgg16((target + 1) * 255 * 0.5, resize_images=True, return_lpips=True)
        synth_features = self.vgg16((synth + 1) * 255 * 0.5, resize_images=True, return_lpips=True)
        return (target_features - synth_features).square().sum()

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        """Compute the loss for one training phase and backpropagate it.

        Only 'Gmain' / 'Gboth' currently do work. They reconstruct the real images with G
        (mapping conditioned on zeros, synthesis on ``real_c``) and backpropagate L1 plus
        perceptual losses on both the 'image' and 'image_raw' outputs, scaled by ``gain``.
        The adversarial, density-regularization and discriminator phases are commented out.

        Args:
            phase: One of 'Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'.
            real_img: Real image batch. It is also resized to the neural-rendering resolution
                to form the 'image_raw' reconstruction target.
            real_c: Conditioning labels for ``real_img``.
            gen_z: Latent codes fed to G's mapping network.
            gen_c: Conditioning for generated samples (read only by the commented-out phases).
            gain: Multiplier applied to the loss before ``backward()``.
            cur_nimg: Number of images seen so far; drives the blur and swapping schedules.
        """
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        if self.G.rendering_kwargs.get('density_reg', 0) == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
        blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
        r1_gamma = self.r1_gamma  # read only by the commented-out R1 code below

        # Generator-pose-conditioning swap probability, faded from 1 to gpc_reg_prob;
        # read only by the commented-out phases below.
        alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1
        swapping_prob = (1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None

        # if self.neural_rendering_resolution_final is not None:
        #     alpha = min(cur_nimg / (self.neural_rendering_resolution_fade_kimg * 1e3), 1)
        #     neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * (1 - alpha) + self.neural_rendering_resolution_final * alpha))
        # else:
        neural_rendering_resolution = self.neural_rendering_resolution_initial

        real_img_raw = self.torch_resize(real_img)

        if self.blur_raw_target:
            blur_size = np.floor(blur_sigma * 3)
            if blur_size > 0:
                f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(blur_sigma).square().neg().exp2()
                real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())

        real_img = {'image': real_img, 'image_raw': real_img_raw}

        # Gmain: Maximize logits for generated images.
        # if phase in ['Gmain', 'Gboth']:
        #     with torch.autograd.profiler.record_function('Gmain_forward'):
        #         gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution)
        #         gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
        #         training_stats.report('Loss/scores/fake', gen_logits)
        #         training_stats.report('Loss/signs/fake', gen_logits.sign())
        #         loss_Gmain = torch.nn.functional.softplus(-gen_logits)
        #         training_stats.report('Loss/G/gan_loss', loss_Gmain)
        #     with torch.autograd.profiler.record_function('Gmain_backward'):
        #         (loss_Gmain.mean()).mul(gain).backward()

        # Minimize reconstruction loss.
        if phase in ['Gmain', 'Gboth']:
            # Keep gradients out of the super-resolution parameters. This flag is not
            # restored here, so it stays False after the call (TODO confirm the caller
            # re-enables it if that is intended).
            for param in self.G.superresolution.parameters():
                param.requires_grad = False
            recon_img, _ = self.run_G(gen_z, real_c, swapping_prob=None, neural_rendering_resolution=neural_rendering_resolution)

            l1_loss_raw = self.l1_loss(real_img['image_raw'], recon_img['image_raw'])
            p_loss_raw = self._perceptual_loss(real_img['image_raw'], recon_img['image_raw'])
            l1_loss = self.l1_loss(real_img['image'], recon_img['image'])
            p_loss = self._perceptual_loss(real_img['image'], recon_img['image'])

            # Stat keys (including the 'percectual' spelling) are kept unchanged so existing logs stay comparable.
            training_stats.report('Loss/G/l1_loss', l1_loss)
            training_stats.report('Loss/G/l1_loss_raw', l1_loss_raw)
            training_stats.report('Loss/G/percectual_loss', p_loss)
            training_stats.report('Loss/G/percectual_loss_raw', p_loss_raw)
            # Scale by gain, like every other phase in this class.
            (p_loss_raw + p_loss + l1_loss_raw + l1_loss).mul(gain).backward()

        # Density Regularization
        # if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'l1':
        #     if swapping_prob is not None:
        #         c_swapped = torch.roll(gen_c.clone(), 1, 0)
        #         c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
        #     else:
        #         c_gen_conditioning = torch.zeros_like(gen_c)

        #     ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
        #     initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
        #     perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist']
        #     all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
        #     sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
        #     sigma_initial = sigma[:, :sigma.shape[1]//2]
        #     sigma_perturbed = sigma[:, sigma.shape[1]//2:]

        #     TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
        #     TVloss.mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        # loss_Dgen = 0
        # if phase in ['Dmain', 'Dboth']:
        #     with torch.autograd.profiler.record_function('Dgen_forward'):
        #         gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution, update_emas=True)
        #         gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
        #         training_stats.report('Loss/scores/fake', gen_logits)
        #         training_stats.report('Loss/signs/fake', gen_logits.sign())
        #         loss_Dgen = torch.nn.functional.softplus(gen_logits)
        #     with torch.autograd.profiler.record_function('Dgen_backward'):
        #         loss_Dgen.mean().mul(gain).backward()

        # # Dmain: Maximize logits for real images.
        # # Dr1: Apply R1 regularization.
        # if phase in ['Dmain', 'Dreg', 'Dboth']:
        #     name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
        #     with torch.autograd.profiler.record_function(name + '_forward'):
        #         real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
        #         real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
        #         real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw}

        #         real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
        #         training_stats.report('Loss/scores/real', real_logits)
        #         training_stats.report('Loss/signs/real', real_logits.sign())

        #         loss_Dreal = 0
        #         if phase in ['Dmain', 'Dboth']:
        #             loss_Dreal = torch.nn.functional.softplus(-real_logits)
        #             training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

        #         loss_Dr1 = 0
        #         if phase in ['Dreg', 'Dboth']:
        #             if self.dual_discrimination:
        #                 with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
        #                     r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_raw']], create_graph=True, only_inputs=True)
        #                     r1_grads_image = r1_grads[0]
        #                     r1_grads_image_raw = r1_grads[1]
        #                 r1_penalty = r1_grads_image.square().sum([1,2,3]) + r1_grads_image_raw.square().sum([1,2,3])
        #             loss_Dr1 = r1_penalty * (r1_gamma / 2)
        #             training_stats.report('Loss/r1_penalty', r1_penalty)
        #             training_stats.report('Loss/D/reg', loss_Dr1)

        #     with torch.autograd.profiler.record_function(name + '_backward'):
        #         (loss_Dreal + loss_Dr1).mean().mul(gain).backward()

#----------------------------------------------------------------------------
